Tutorial: a gene expression error model with spike-ins and biological replicates, fit to simulated data¶

In [1]:
import pymc as pm
import arviz as az
#import bebi103

import numpy as np
import pandas as pd
import iqplot

import bokeh.io
import bokeh.plotting
from bokeh.layouts import gridplot
bokeh.io.output_notebook()

import colorcet

from bokeh.plotting import figure, show, output_notebook
from bokeh.layouts import row, column
from bokeh.models import Span, ColumnDataSource
from bokeh.layouts import gridplot

import matplotlib.pyplot as plt

from tqdm import tqdm
Loading BokehJS ...

Model descriptions¶

Biological replicates¶

$$ \underset{\textcolor{purple}{\text{posterior}}}{\pi \left( \underline{\alpha}, \underline{\alpha_{global}}, \underline{b^{(x)}},\underline{b^{(s)}}, \mu_b, \sigma_b, \underline{x^{mRNA}}, X^{mRNA} | \underline{x^{seq}}, \underline{s^{seq}}, \underline{s^{spike}} \right)} \propto$$$$\underset{\textcolor{purple}{\text{spike-in likelihood}}}{\pi \left(\underline{s^{spike}}, \underline{s^{seq}} | \underline{b^{(s)}} \right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(\underline{b^{(s)}} | \mu_b, \sigma_b \right) \pi \left(\mu_b \right) \pi \left(\sigma_b \right)}\times$$$$\underset{\textcolor{purple}{\text{txtome likelihood}}}{\pi \left(\underline{x^{seq}} | \underline{x^{mRNA}}, \underline{b^{(x)}} \right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(\underline{b^{(x)}} | \mu_b, \sigma_b \right) }\times$$$$ \underset{\textcolor{purple}{\text{idealized likelihood}}}{\pi \left( \underline{x^{mRNA}} | X^{mRNA}, \underline{\alpha}\right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(X^{mRNA} \right) \pi \left(\underline{\alpha} | \underline{\alpha_{global}} \right) \pi \left( \underline{\alpha_{global}}\right)}$$

Generate simulated data¶

In [2]:
# Sequenced transcript counts: one row per biological replicate, one column
# per gene.
x_seq = np.array([
    [500, 1, 10, 9000, 750, 2, 36962],
    [450, 5, 5, 1010, 900, 5, 37000],
])

# Sequenced spike-in counts: one row per replicate, one column per spike-in
# species.
s_seq = np.array([
    [4, 14, 175, 2875, 25678],
    [5, 12, 140, 4000, 27000],
])

# Spike-in amounts paired element-wise with s_seq; both replicates use the
# same five values.
s_spike = np.array([
    [2, 20, 200, 2000, 20000],
    [2, 20, 200, 2000, 20000],
])

# Problem dimensions: R replicates, G genes, K spike-in species.
R = x_seq.shape[0]
G = x_seq.shape[1]
K = s_seq.shape[1]

# Per-replicate total count, and each gene's share of that total.
X_mrna_est = x_seq.sum(axis=1)
betas = x_seq / X_mrna_est[:, np.newaxis]

# Gene shares averaged over replicates.
betas_mean = np.mean(betas, axis=0)

Run model¶

In [3]:
# Fixed seed handed to every sampling call below so that a Restart & Run All
# reproduces the same posterior, prior-predictive and posterior-predictive
# draws.
RANDOM_SEED = 20240607

with pm.Model() as model:
    # --- spike-in / transcriptome priors ---------------------------------
    # mu_b and sigma_b are the log-scale location and scale shared by the
    # LogNormal priors on both b_x and b_s.
    mu_b = pm.HalfNormal("mu_b", sigma = 1)
    sigma_b = pm.LogNormal("sigma_b", mu = 0.5, sigma = 0.5)
    # One multiplicative factor per (replicate, gene) ...
    b_x = pm.LogNormal('b_x', mu = mu_b, sigma = sigma_b, shape = (R,G))
    # ... and one per (replicate, spike-in species).
    # TODO(review): the original author marked this prior with "CHECK?".
    b_s = pm.LogNormal("b_s", mu = mu_b, sigma = sigma_b, shape = (R,K))

    # --- idealized priors --------------------------------------------------
    # Global gene fractions; the Dirichlet concentration is betas_mean.
    alpha_global = pm.Dirichlet('alpha_global', a = betas_mean, shape = G)

    # Per-replicate total, centred on the observed row sums X_mrna_est.
    X_mrna = pm.Normal("X_mrna", mu = X_mrna_est, sigma = 2000)
    # conc scales alpha_global into the concentration vector of the
    # per-replicate Dirichlet on alpha.
    conc = pm.HalfNormal('conc', sigma = 50)
    alpha = pm.Dirichlet('alpha', a = conc * alpha_global, shape = (R,G))

    # --- spike-in likelihood -------------------------------------------------
    s_seq_obs = pm.Poisson("s_seq_obs", mu = s_spike * b_s, observed = s_seq)

    # --- transcriptome likelihood --------------------------------------------
    # Broadcast the (R,) totals against the (R, G) fractions.
    x_mrna = X_mrna[:, None] * alpha
    x_seq_obs = pm.Poisson("x_seq_obs", mu = x_mrna * b_x, observed = x_seq)

    # --- sampling ------------------------------------------------------------
    trace = pm.sample(1000, tune = 1000, target_accept = 0.90, random_seed = RANDOM_SEED)
    prior_pred = pm.sample_prior_predictive(samples = 1000, random_seed = RANDOM_SEED)
    ppc = pm.sample_posterior_predictive(trace, random_seed = RANDOM_SEED)
    
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_b, sigma_b, b_x, b_s, alpha_global, X_mrna, conc, alpha]
/opt/anaconda3/envs/pymc_env/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 129 seconds.
Sampling: [X_mrna, alpha, alpha_global, b_s, b_x, conc, mu_b, s_seq_obs, sigma_b, x_seq_obs]
Sampling: [s_seq_obs, x_seq_obs]
/opt/anaconda3/envs/pymc_env/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

In [4]:
# Tabulate posterior summary statistics and sampling diagnostics for every
# variable in `trace`; the cell output is the table itself.
az.summary(trace)
Out[4]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
X_mrna[0] 47208.093 2009.271 43536.579 51029.575 32.723 29.571 3767.0 3034.0 1.0
X_mrna[1] 39095.431 1990.006 35426.596 43012.302 31.249 32.452 4060.0 2790.0 1.0
mu_b 0.153 0.118 0.000 0.358 0.002 0.002 1922.0 1677.0 1.0
sigma_b 0.606 0.167 0.350 0.934 0.005 0.004 1248.0 1721.0 1.0
b_x[0, 0] 1.234 0.892 0.165 2.530 0.017 0.080 3105.0 2641.0 1.0
b_x[0, 1] 1.347 1.051 0.156 2.985 0.020 0.099 3250.0 2467.0 1.0
b_x[0, 2] 1.336 1.014 0.117 2.901 0.019 0.049 3113.0 2432.0 1.0
b_x[0, 3] 2.387 1.359 0.614 4.617 0.028 0.059 2060.0 1843.0 1.0
b_x[0, 4] 1.179 0.796 0.237 2.460 0.015 0.035 3007.0 2146.0 1.0
b_x[0, 5] 1.374 1.394 0.109 3.065 0.032 0.240 3058.0 2493.0 1.0
b_x[0, 6] 0.907 0.072 0.786 1.033 0.002 0.002 2238.0 2014.0 1.0
b_x[1, 0] 1.285 1.141 0.176 2.690 0.029 0.141 2450.0 2106.0 1.0
b_x[1, 1] 1.344 1.158 0.126 2.906 0.023 0.147 3154.0 2304.0 1.0
b_x[1, 2] 1.304 0.962 0.109 2.792 0.018 0.037 3056.0 2218.0 1.0
b_x[1, 3] 0.802 0.572 0.075 1.754 0.012 0.027 1870.0 1332.0 1.0
b_x[1, 4] 1.287 0.824 0.236 2.603 0.016 0.034 2901.0 2357.0 1.0
b_x[1, 5] 1.382 1.373 0.163 3.001 0.031 0.238 3197.0 2429.0 1.0
b_x[1, 6] 1.042 0.081 0.902 1.181 0.002 0.003 1890.0 1631.0 1.0
b_s[0, 0] 1.619 0.679 0.541 2.875 0.009 0.013 6073.0 2976.0 1.0
b_s[0, 1] 0.766 0.181 0.450 1.117 0.003 0.003 5071.0 3065.0 1.0
b_s[0, 2] 0.879 0.067 0.755 1.005 0.001 0.001 4850.0 2989.0 1.0
b_s[0, 3] 1.438 0.027 1.388 1.487 0.000 0.000 5573.0 2695.0 1.0
b_s[0, 4] 1.284 0.008 1.269 1.298 0.000 0.000 5330.0 2900.0 1.0
b_s[1, 0] 1.893 0.776 0.595 3.319 0.012 0.015 4209.0 3001.0 1.0
b_s[1, 1] 0.686 0.167 0.387 1.012 0.002 0.003 4458.0 2425.0 1.0
b_s[1, 2] 0.709 0.060 0.601 0.827 0.001 0.001 5152.0 2950.0 1.0
b_s[1, 3] 1.999 0.032 1.937 2.057 0.000 0.001 5326.0 2838.0 1.0
b_s[1, 4] 1.350 0.008 1.335 1.366 0.000 0.000 5124.0 2817.0 1.0
alpha_global[0] 0.017 0.012 0.001 0.038 0.000 0.000 3284.0 2642.0 1.0
alpha_global[1] 0.003 0.003 0.000 0.007 0.000 0.000 3333.0 2559.0 1.0
alpha_global[2] 0.003 0.003 0.000 0.008 0.000 0.000 3017.0 2376.0 1.0
alpha_global[3] 0.070 0.042 0.012 0.142 0.001 0.002 1790.0 1434.0 1.0
alpha_global[4] 0.025 0.016 0.002 0.052 0.000 0.001 3152.0 2941.0 1.0
alpha_global[5] 0.003 0.003 0.000 0.007 0.000 0.000 3491.0 2860.0 1.0
alpha_global[6] 0.878 0.049 0.792 0.955 0.001 0.002 2061.0 1673.0 1.0
conc 77.411 33.040 21.391 140.421 0.584 0.456 2850.0 2489.0 1.0
alpha[0, 0] 0.012 0.008 0.002 0.026 0.000 0.000 3097.0 2530.0 1.0
alpha[0, 1] 0.000 0.000 0.000 0.000 0.000 0.000 2661.0 2299.0 1.0
alpha[0, 2] 0.000 0.000 0.000 0.001 0.000 0.000 2830.0 2414.0 1.0
alpha[0, 3] 0.100 0.048 0.025 0.188 0.001 0.002 2030.0 1799.0 1.0
alpha[0, 4] 0.019 0.012 0.003 0.039 0.000 0.001 3030.0 2087.0 1.0
alpha[0, 5] 0.000 0.000 0.000 0.000 0.000 0.000 2799.0 2741.0 1.0
alpha[0, 6] 0.868 0.050 0.780 0.953 0.001 0.002 2039.0 1621.0 1.0
alpha[1, 0] 0.013 0.009 0.002 0.028 0.000 0.000 2439.0 2141.0 1.0
alpha[1, 1] 0.000 0.000 0.000 0.000 0.000 0.000 3001.0 2549.0 1.0
alpha[1, 2] 0.000 0.000 0.000 0.000 0.000 0.000 3259.0 2757.0 1.0
alpha[1, 3] 0.049 0.039 0.007 0.112 0.001 0.002 1828.0 1333.0 1.0
alpha[1, 4] 0.024 0.014 0.004 0.048 0.000 0.001 2868.0 2246.0 1.0
alpha[1, 5] 0.000 0.000 0.000 0.000 0.000 0.000 3093.0 2642.0 1.0
alpha[1, 6] 0.913 0.043 0.839 0.969 0.001 0.002 1691.0 1317.0 1.0

Prior predictive checks¶

In [5]:
# Prior predictive draws of the observed counts, indexed below as
# (chain, draw, replicate, gene).
prior_samples = prior_pred.prior_predictive['x_seq_obs'].values
plots = []

# One histogram per (replicate, gene) with the observed count marked.
for rep in range(R):
    for gene_ind in range(G):
        prior_gene = prior_samples[0, :, rep, gene_ind]  # chain 0 draws for this replicate/gene
        x_seq_gene = x_seq[rep, gene_ind]  # observed count for this replicate/gene

        p1 = figure(
            title=f"Prior Predictive Check: Rep {rep+1}, Gene {gene_ind}",
            x_axis_label="x_seq_obs",
            y_axis_label="Count",
            width=400,
            height=300,
            background_fill_color="#fafafa"
        )

        hist, edges = np.histogram(prior_gene, bins=50)

        p1.quad(
            top=hist,
            bottom=0,
            left=edges[:-1],
            right=edges[1:],
            fill_color="skyblue",
            line_color="white",
            legend_label="1000 random samples"
        )

        # Vertical marker at the observed count. The thickness property of a
        # Bokeh line glyph is `line_width`; `width` is not a line property.
        p1.line(
            x=(x_seq_gene, x_seq_gene),
            y=(0, 600),
            line_width=3,
            color='black',
            line_dash='dotted'
        )

        # Fixed upper x limit; observed counts above 1000 fall outside the
        # initial view.
        p1.x_range.end = 1000

        plots.append(p1)

# Lay the figures out four per row and display them.
n_cols = 4
grid = [plots[i:i+n_cols] for i in range(0, len(plots), n_cols)]
show(gridplot(grid))

Posterior predictive check - gene counts¶

In [6]:
post_samples = ppc.posterior_predictive['x_seq_obs'].values
# Pool chains and draws: (chain, draw, replicate, gene) -> (sample, replicate, gene).
post_samples_comb = post_samples.reshape(-1, post_samples.shape[2], post_samples.shape[3])
plots = []

# One histogram per (replicate, gene) with the observed count marked.
for rep in range(R):
    for gene_ind in range(G):
        post_gene = post_samples_comb[:, rep, gene_ind]  # pooled posterior predictive draws
        x_seq_gene = x_seq[rep, gene_ind]  # observed count for this replicate/gene

        p1 = figure(
            title=f"Posterior Predictive Check: Rep {rep+1}, Gene {gene_ind}",
            x_axis_label="x_seq_obs",
            y_axis_label="Count",
            width=400,
            height=300,
            background_fill_color="#fafafa"
        )

        # Histogram the posterior predictive draws for THIS replicate/gene.
        # (The previous version histogrammed `prior_gene`, a variable left
        # over from the prior predictive cell, so every panel showed the same
        # prior histogram.)
        hist, edges = np.histogram(post_gene, bins=50)

        p1.quad(
            top=hist,
            bottom=0,
            left=edges[:-1],
            right=edges[1:],
            fill_color="skyblue",
            line_color="white",
            # Report the actual number of pooled draws rather than a fixed count.
            legend_label=f"{post_gene.size} posterior predictive samples"
        )

        # Vertical marker at the observed count. The thickness property of a
        # Bokeh line glyph is `line_width`; `width` is not a line property.
        p1.line(
            x=(x_seq_gene, x_seq_gene),
            y=(0, 600),
            line_width=3,
            color='black',
            line_dash='dotted'
        )

        plots.append(p1)

# Lay the figures out four per row and display them.
n_cols = 4
grid = [plots[i:i+n_cols] for i in range(0, len(plots), n_cols)]
show(gridplot(grid))

Posterior predictive check - alpha¶

In [7]:
post_alpha = trace.posterior['alpha'].values
# Pool chains and draws: (chain, draw, replicate, gene) -> (sample, replicate, gene).
post_alpha_comb = post_alpha.reshape(-1, post_alpha.shape[2], post_alpha.shape[3])

plots = []

# One histogram of posterior alpha per (replicate, gene), with the observed
# count fraction from `betas` marked.
for rep in range(R):
    for gene_ind in range(G):
        post_gene = post_alpha_comb[:, rep, gene_ind]  # pooled posterior draws of alpha
        alpha_gene = betas[rep, gene_ind]  # observed fraction for this replicate/gene

        # Include the replicate in the title; without it the panels for
        # different replicates of the same gene are indistinguishable.
        p1 = figure(
            title=f"Posterior Predictive Check: alpha, Rep {rep+1}, Gene {gene_ind}",
            x_axis_label="alpha",
            y_axis_label="Count",
            width=400,
            height=300,
            background_fill_color="#fafafa"
        )

        hist, edges = np.histogram(post_gene, bins=100)

        p1.quad(
            top=hist,
            bottom=0,
            left=edges[:-1],
            right=edges[1:],
            fill_color="skyblue",
            line_color="white",
            # Report the actual number of pooled draws rather than a fixed count.
            legend_label=f"{post_gene.size} posterior samples"
        )

        # Vertical marker at the observed fraction. The thickness property of
        # a Bokeh line glyph is `line_width`; `width` is not a line property.
        p1.line(
            x=(alpha_gene, alpha_gene),
            y=(0, hist.max()),
            line_width=3,
            color='black',
            line_dash='dotted'
        )

        plots.append(p1)

# Lay the figures out four per row and display them.
n_cols = 4
grid = [plots[i:i+n_cols] for i in range(0, len(plots), n_cols)]
show(gridplot(grid))

Posterior predictive check - alpha global¶

In [8]:
post_alpha_global = trace.posterior['alpha_global'].values  # (chains, draws, G)
post_alpha_global_comb = post_alpha_global.reshape(-1, post_alpha_global.shape[-1])  # (chains*draws, G)

plots = []

# One histogram of posterior alpha_global per gene, with every replicate's
# observed fraction from `betas` marked.
for i in range(G):
    post_gene = post_alpha_global_comb[:, i]  # pooled posterior draws for gene i
    alpha_gene = betas[:, i]  # observed fraction of gene i in each replicate

    p1 = figure(
        title=f"Posterior Predictive Check: alpha_global gene {i}",
        x_axis_label="alpha_global",
        y_axis_label="Count",
        width=400,
        height=300,
        background_fill_color="#fafafa"
    )

    hist, edges = np.histogram(post_gene, bins=100)

    p1.quad(
        top=hist,
        bottom=0,
        left=edges[:-1],
        right=edges[1:],
        fill_color="skyblue",
        line_color="white",
        legend_label="Posterior samples"
    )

    # One vertical marker per replicate, however many replicates there are
    # (previously two copy-pasted calls hard-coded exactly two). The thickness
    # property of a Bokeh line glyph is `line_width`; `width` is not a line
    # property.
    for rep_fraction in alpha_gene:
        p1.line(
            x=(rep_fraction, rep_fraction),
            y=(0, hist.max()),
            line_width=3,
            color='black',
            line_dash='dotted',
            legend_label='Observed alpha_global'
        )

    plots.append(p1)

# Display all gene panels, three per row.
show(gridplot(plots, ncols=3))

Assess model quality¶

Posterior traces¶

In [9]:
# One trace plot per element of alpha (replicate r, gene g).
for r in range(R):
    for g in range(G):
        # Select the single element through `coords` instead of deep-copying
        # the whole InferenceData and overwriting its posterior group on
        # every iteration.
        az.plot_trace(
            trace,
            var_names=['alpha'],
            coords={'alpha_dim_0': [r], 'alpha_dim_1': [g]},
        )
        plt.suptitle(f"Trace plot for alpha[{r},{g}] (Rep {r+1}, Gene {g})")
        plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [10]:
# One trace plot per element of alpha_global (gene g).
for g in range(G):
    # Select the single element through `coords` instead of deep-copying the
    # whole InferenceData and overwriting its posterior group on every
    # iteration.
    az.plot_trace(
        trace,
        var_names=['alpha_global'],
        coords={'alpha_global_dim_0': [g]},
    )
    plt.suptitle(f"Trace plot for alpha_global species {g}")
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [11]:
# Trace plots for mu_b and sigma_b, the two hyperparameters shared by the
# b_x and b_s priors in the model cell.
az.plot_trace(trace, var_names = ['mu_b', 'sigma_b'])
Out[11]:
array([[<Axes: title={'center': 'mu_b'}>,
        <Axes: title={'center': 'mu_b'}>],
       [<Axes: title={'center': 'sigma_b'}>,
        <Axes: title={'center': 'sigma_b'}>]], dtype=object)
No description has been provided for this image
In [12]:
# Trace plots for the per-replicate totals X_mrna and the Dirichlet
# concentration scale conc.
az.plot_trace(trace, var_names = ['X_mrna', 'conc'])
Out[12]:
array([[<Axes: title={'center': 'X_mrna'}>,
        <Axes: title={'center': 'X_mrna'}>],
       [<Axes: title={'center': 'conc'}>,
        <Axes: title={'center': 'conc'}>]], dtype=object)
No description has been provided for this image
In [13]:
# One trace plot per element of b_x (replicate r, gene g).
for r in range(R):
    for g in range(G):
        # Select the single element through `coords` instead of deep-copying
        # the whole InferenceData and overwriting its posterior group on
        # every iteration.
        az.plot_trace(
            trace,
            var_names=['b_x'],
            coords={'b_x_dim_0': [r], 'b_x_dim_1': [g]},
        )
        plt.suptitle(f"Trace plot for b_x[{r},{g}] (Rep {r+1}, Gene {g})")
        plt.show()
    
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [14]:
# One trace plot per element of b_s (replicate r, spike-in species k).
for r in range(R):
    for k in range(K):
        # Select the single element through `coords` instead of deep-copying
        # the whole InferenceData and overwriting its posterior group on
        # every iteration.
        az.plot_trace(
            trace,
            var_names=['b_s'],
            coords={'b_s_dim_0': [r], 'b_s_dim_1': [k]},
        )
        plt.suptitle(f"Trace plot for b_s[{r},{k}] (Rep {r+1}, spike-in species {k})")
        plt.show()
    
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Nonidentifiability checks¶

X_mrna vs b_x¶

In [15]:
# Joint posterior of the replicate total X_mrna and each gene's b_x:
# one figure per replicate, one KDE panel per gene.
ncols = 4
nrows = -(-G // ncols)  # ceil(G / ncols)

for rep in range(R):
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*ncols, 3*nrows))
    axes = axes.flatten()
    fig.suptitle(f"Replicate {rep}", fontsize=16)

    for g, panel in zip(range(G), axes):
        X_mrna_i = trace.posterior['X_mrna'].isel(X_mrna_dim_0=rep)
        b_x_i = trace.posterior['b_x'].isel(b_x_dim_0=rep, b_x_dim_1=g)

        # Dict keys become the axis labels of the pair plot.
        pair_vars = {
            f'X_mrna[{rep}]': X_mrna_i,
            f'b_x[{rep},{g}]': b_x_i,
        }
        az.plot_pair(pair_vars, kind='kde', ax=panel)

    # Drop the grid cells that have no gene assigned to them.
    for spare in axes[G:]:
        fig.delaxes(spare)

    plt.show()
No description has been provided for this image
No description has been provided for this image

X_mrna vs alpha¶

In [16]:
# Joint posterior of the replicate total X_mrna and each gene's alpha:
# one figure per replicate, one KDE panel per gene.
ncols = 4
nrows = -(-G // ncols)  # ceil(G / ncols)

for rep in range(R):
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*ncols, 3*nrows))
    axes = axes.flatten()
    fig.suptitle(f"Replicate {rep}", fontsize=16)

    for g, panel in zip(range(G), axes):
        X_mrna_i = trace.posterior['X_mrna'].isel(X_mrna_dim_0=rep)
        alpha_i = trace.posterior['alpha'].isel(alpha_dim_0=rep, alpha_dim_1=g)

        # Dict keys become the axis labels of the pair plot.
        pair_vars = {
            f'X_mrna[{rep}]': X_mrna_i,
            f'alpha[{rep},{g}]': alpha_i,
        }
        az.plot_pair(pair_vars, kind='kde', ax=panel)

    # Drop the grid cells that have no gene assigned to them.
    for spare in axes[G:]:
        fig.delaxes(spare)

    plt.show()
No description has been provided for this image
No description has been provided for this image

alpha vs b_x¶

In [17]:
# Joint posterior of each gene's b_x and alpha:
# one figure per replicate, one KDE panel per gene.
ncols = 4
nrows = -(-G // ncols)  # ceil(G / ncols)

for rep in range(R):
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*ncols, 3*nrows))
    axes = axes.flatten()
    fig.suptitle(f"Replicate {rep}", fontsize=16)

    for g, panel in zip(range(G), axes):
        b_x_i = trace.posterior['b_x'].isel(b_x_dim_0=rep, b_x_dim_1=g)
        alpha_i = trace.posterior['alpha'].isel(alpha_dim_0=rep, alpha_dim_1=g)

        # Dict keys become the axis labels of the pair plot.
        pair_vars = {
            f'b_x[{rep},{g}]': b_x_i,
            f'alpha[{rep},{g}]': alpha_i,
        }
        az.plot_pair(pair_vars, kind='kde', ax=panel)

    # Drop the grid cells that have no gene assigned to them.
    for spare in axes[G:]:
        fig.delaxes(spare)

    plt.show()
No description has been provided for this image
No description has been provided for this image

alpha vs b_s¶

In [18]:
# Joint posterior of each spike-in species' b_s and the alpha at the same
# index: one figure per replicate, one KDE panel per spike-in species.
# NOTE(review): alpha is indexed by gene and b_s by spike-in species, so this
# pairs gene k with spike-in k purely by position and needs K <= G -- confirm
# this pairing is the intended comparison.
ncols = 4
nrows = (K + ncols - 1) // ncols  # ceiling division

for rep in range(R):  # loop through replicates
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*ncols, 3*nrows))
    axes = axes.flatten()
    fig.suptitle(f"Replicate {rep}", fontsize=16)

    for k in range(K):
        b_s_i = trace.posterior['b_s'].isel(b_s_dim_0=rep, b_s_dim_1=k)
        alpha_i = trace.posterior['alpha'].isel(alpha_dim_0=rep, alpha_dim_1=k)

        # Label b_s with its own index k. The previous version formatted the
        # label with `g`, a loop variable left over from an earlier cell, so
        # every panel carried the same, wrong b_s index.
        az.plot_pair(
            {'b_s[{},{}]'.format(rep, k): b_s_i, 'alpha[{},{}]'.format(rep, k): alpha_i},
            kind='kde',
            ax=axes[k]
        )

    # Hide any unused axes
    for j in range(K, len(axes)):
        fig.delaxes(axes[j])

    plt.show()
No description has been provided for this image
No description has been provided for this image
In [19]:
# Pairwise KDE plots across all elements of X_mrna and b_s.
# NOTE: per the warning in this cell's output, ArviZ caps the grid at
# rcParams['plot.max_subplots'] and draws only an 8x8 grid, so some b_s
# elements are not shown.
az.plot_pair(trace, var_names = [ 'X_mrna', 'b_s',], kind = 'kde')
/opt/anaconda3/envs/pymc_env/lib/python3.12/site-packages/arviz/plots/backends/matplotlib/pairplot.py:233: UserWarning: rcParams['plot.max_subplots'] (40) is smaller than the number of resulting pair plots with these variables, generating only a 8x8 grid
  warnings.warn(
Out[19]:
array([[<Axes: ylabel='X_mrna\n1'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n0, 0'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n0, 1'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n0, 2'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n0, 3'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n0, 4'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: ylabel='b_s\n1, 0'>, <Axes: >, <Axes: >, <Axes: >,
        <Axes: >, <Axes: >, <Axes: >, <Axes: >],
       [<Axes: xlabel='X_mrna\n0', ylabel='b_s\n1, 1'>,
        <Axes: xlabel='X_mrna\n1'>, <Axes: xlabel='b_s\n0, 0'>,
        <Axes: xlabel='b_s\n0, 1'>, <Axes: xlabel='b_s\n0, 2'>,
        <Axes: xlabel='b_s\n0, 3'>, <Axes: xlabel='b_s\n0, 4'>,
        <Axes: xlabel='b_s\n1, 0'>]], dtype=object)
No description has been provided for this image
In [20]:
# Record the Python, IPython and package versions this notebook was run with.
# NOTE(review): bebi103 is listed here although its import at the top of the
# notebook is commented out -- confirm it should still be reported.
%load_ext watermark
%watermark -v -p numpy,bokeh,pymc,arviz,bebi103,iqplot,pandas,matplotlib,colorcet,tqdm,jupyterlab
/opt/anaconda3/envs/pymc_env/lib/python3.12/site-packages/bebi103/viz.py:38: UserWarning: DataShader import failed with error "No module named 'datashader'".
Features requiring DataShader will not work and you will get exceptions.
  warnings.warn(
/opt/anaconda3/envs/pymc_env/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.30.0

numpy     : 1.26.4
bokeh     : 3.7.3
pymc      : 5.19.1
arviz     : 0.21.0
bebi103   : 0.1.26
iqplot    : 0.3.7
pandas    : 2.2.3
matplotlib: 3.9.4
colorcet  : 3.1.0
tqdm      : 4.67.1
jupyterlab: 4.3.3